from PIL import Image  # 从 PIL 库中导入 Image，用于打开图像文件
from torch.utils.data import Dataset  # 导入 PyTorch 中的 Dataset 基类，用来自定义数据集


class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.file_length = len(self.file_list)
        return self.file_length

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)  # 将图片输入数据预处理函数

        label = img_path.split("\\")[-1].split(".")[0]  # 获取图片的标签
        label = 1 if label == "dog" else 0  # 表示 0-Dog; 1-Cat

        return img_transformed, label
